# load all required libraries
library(knitr)
library(plotly)
library(ggplot2)
library(gridExtra)
library(caret)
library(dplyr)
library(tidyr)
library(corrplot)
library(R2jags)
library(ggmcmc)

Introduction and motivation
Coronary Artery Disease (CAD; in Italian “Coronaropatia”) is a type of Cardiovascular Disease (CVD) in which the coronary arteries cannot deliver enough oxygen-rich blood to the heart. According to the National Heart, Lung and Blood Institute website, “CVDs are the leading cause of death in the United States” (695,000 deaths per year in the US, 1 in every 5 deaths) and “coronary heart diseases are the most common type of CVD” (375,476 deaths per year in the US), with similar proportions around the whole world. In 2021 CVDs were the world’s single biggest killer, with 20.5 million deaths globally.
CAD is caused by the build-up of plaque made of cholesterol in the coronary arteries, a phenomenon called atherosclerosis. The deposit can partially or totally block the flow of blood in the arteries, or facilitate their blockage through the formation of blood clots.
Symptoms of CAD include chest pain, heartburn and shortness of breath, though their presence varies from person to person even among patients with the same type of coronary heart disease. Much more commonly, CAD presents no symptoms at all, and the first manifest sign of the disorder is a heart attack, which can itself cause a cardiac arrest, a life-threatening medical emergency if not treated within minutes: 25% of people who have CAD die suddenly without any previous symptom.
Given the difficulty of detection, the severity of its effects and its widespread diffusion throughout the world’s population, a timely and accurate diagnosis of the disease in its early and mostly asymptomatic phase is extremely important for early treatment and could save many lives every year.
The proposed study uses Bayesian inference to build a model capable of detecting CAD from non-invasive clinical parameters of patients (as opposed to the coronary angiogram, described below). The Bayesian approach is particularly suitable here given the low number of entries in the dataset under examination.
Data
Dataset description
The data used for this project come from a publicly available dataset called “Z-Alizadeh Sani”, collected in 2012 for research purposes by Dr. Zahra Alizadeh Sani, associate professor of Cardiology at Iran University, and donated to the UC Irvine Machine Learning Repository (available here) in 2017.
A particularly positive aspect of this dataset is its completeness: there are no missing values.
The dataset contains the records of 303 individuals, random visitors to the Shaheed Rajaei Cardiovascular, Medical and Research Center of Tehran. Each sample has 55 features and a target category. All features recorded in the dataset were chosen by the author because the current medical literature considers them indicators of CAD. The features are arranged in four groups:
- Demographic
- Symptom and examination
- ECG (electrocardiogram)
- Laboratory and echo
The ground-truth classification is the result of a coronary angiogram performed on the patient, an imaging technique that visualizes blood vessels, arteries and veins using X-rays and a radio-opaque contrast agent injected into the bloodstream. While accurate, this method is rarely performed due to its high cost and the invasiveness of the procedure.
Each patient falls into one of two possible categories: CAD or Normal. A patient is categorized as CAD if the diameter narrowing of an artery is greater than or equal to 50%, and as Normal otherwise. Of the 303 individuals, 216 samples have CAD and the rest are healthy.
The following table summarizes the features of the dataset, their meaning and the values they take:
| Feature Type | Feature Name | Range |
|---|---|---|
| Demographic | Age (years) | 30–86 |
| | Weight (kg) | 48–120 |
| | Length (height, cm) | 140–188 |
| | Sex | male, female |
| | BMI (body mass index, kg/m²) | 18–41 |
| | DM (history of diabetes mellitus) | yes, no (binary) |
| | HTN (history of hypertension) | yes, no (binary) |
| | Current smoker | yes, no (binary) |
| | Ex Smoker | yes, no (binary) |
| | FH (history of CVD in first-degree relatives) | yes, no (binary) |
| | Obesity | yes, no (string) |
| | CRF (chronic renal failure) | yes, no (string) |
| | CVA (cerebrovascular accident) | yes, no (string) |
| | Airway disease | yes, no (string) |
| | Thyroid Disease | yes, no (string) |
| | CHF (congestive heart failure) | yes, no (string) |
| | DLP (dyslipidemia, high lipids in blood) | yes, no (string) |
| Symptom and Examination | BP (blood pressure, mmHg) | 90–190 |
| | PR (pulse rate, bpm) | 50–110 |
| | Edema (fluid retention in body tissue) | yes, no (binary) |
| | Weak peripheral pulse | yes, no (string) |
| | Lung rales (abnormal lung sounds) | yes, no (string) |
| | Systolic murmur | yes, no (string) |
| | Diastolic murmur | yes, no (string) |
| | Typical Chest Pain | yes, no (binary) |
| | Dyspnea (shortness of breath) | yes, no (string) |
| | Function class (frequency of symptoms) | 1, 2, 3, 4 |
| | Atypical | yes, no (string) |
| | Nonanginal CP (chest pain at rest) | yes, no (string) |
| | Exertional CP (chest pain during physical exertion) | yes, no (string) |
| | Low Th Ang (low threshold angina) | yes, no (string) |
| ECG (electrocardiogram) | Q Wave | yes, no (binary) |
| | ST Elevation | yes, no (binary) |
| | ST Depression | yes, no (binary) |
| | T inversion | yes, no (binary) |
| | LVH (left ventricular hypertrophy) | yes, no (string) |
| | Poor R progression | yes, no (string) |
| | BBB (bundle branch block) | no, left, right |
| Laboratory and Echo | FBS (fasting blood sugar, mg/dl) | 62–400 |
| | Cr (creatinine, mg/dl) | 0.5–2.2 |
| | TG (triglyceride, mg/dl) | 37–1050 |
| | LDL (low density lipoprotein, mg/dl) | 18–232 |
| | HDL (high density lipoprotein, mg/dl) | 15–111 |
| | BUN (blood urea nitrogen, mg/dl) | 6–52 |
| | ESR (erythrocyte sedimentation rate, mm/h) | 1–90 |
| | HB (hemoglobin, g/dl) | 8.9–17.6 |
| | K (potassium, mEq/L) | 3.0–6.6 |
| | Na (sodium, mEq/L) | 128–156 |
| | WBC (white blood cells, cells/ml) | 3700–18000 |
| | Lymph (lymphocyte, %) | 7–60 |
| | Neut (neutrophil, %) | 32–89 |
| | PLT (platelet, 1000/ml) | 25–742 |
| | EF (ejection fraction, %) | 9–65 |
| | Region with RWMA (regional wall motion abnormality) | 0, 1, 2, 3, 4 |
| | VHD (valvular heart disease) | normal, mild, moderate, severe |
| Target | Cath (cardiac catheterization) | Cad, Normal |
Preprocessing and cleaning
To get an idea of the raw data, we show some example entries from the dataset:
# read CSV file
data = read.csv("Z-Alizadeh Sani dataset.csv")
# print head
knitr::kable(head(data, 5), col.names = gsub("[.]", " ", names(data)))

| Age | Weight | Length | Sex | BMI | DM | HTN | Current Smoker | EX Smoker | FH | Obesity | CRF | CVA | Airway disease | Thyroid Disease | CHF | DLP | BP | PR | Edema | Weak Peripheral Pulse | Lung rales | Systolic Murmur | Diastolic Murmur | Typical Chest Pain | Dyspnea | Function Class | Atypical | Nonanginal | Exertional CP | LowTH Ang | Q Wave | St Elevation | St Depression | Tinversion | LVH | Poor R Progression | BBB | FBS | CR | TG | LDL | HDL | BUN | ESR | HB | K | Na | WBC | Lymph | Neut | PLT | EF TTE | Region RWMA | VHD | Cath |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 53 | 90 | 175 | Male | 29.38776 | 0 | 1 | 1 | 0 | 0 | Y | N | N | N | N | N | Y | 110 | 80 | 0 | N | N | N | N | 0 | N | 0 | N | N | N | N | 0 | 0 | 1 | 1 | N | N | N | 90 | 0.7 | 250 | 155 | 30 | 8 | 7 | 15.6 | 4.7 | 141 | 5700 | 39 | 52 | 261 | 50 | 0 | N | Cad |
| 67 | 70 | 157 | Fmale | 28.39872 | 0 | 1 | 0 | 0 | 0 | Y | N | N | N | N | N | N | 140 | 80 | 1 | N | N | N | N | 1 | N | 0 | N | N | N | N | 0 | 0 | 1 | 1 | N | N | N | 80 | 1.0 | 309 | 121 | 36 | 30 | 26 | 13.9 | 4.7 | 156 | 7700 | 38 | 55 | 165 | 40 | 4 | N | Cad |
| 54 | 54 | 164 | Male | 20.07733 | 0 | 0 | 1 | 0 | 0 | N | N | N | N | N | N | N | 100 | 100 | 0 | N | N | N | N | 1 | N | 0 | N | N | N | N | 0 | 0 | 0 | 0 | N | N | N | 85 | 1.0 | 103 | 70 | 45 | 17 | 10 | 13.5 | 4.7 | 139 | 7400 | 38 | 60 | 230 | 40 | 2 | mild | Cad |
| 66 | 67 | 158 | Fmale | 26.83865 | 0 | 1 | 0 | 0 | 0 | Y | N | N | N | N | N | N | 100 | 80 | 0 | N | N | N | Y | 0 | Y | 3 | N | Y | N | N | 0 | 0 | 1 | 0 | N | N | N | 78 | 1.2 | 63 | 55 | 27 | 30 | 76 | 12.1 | 4.4 | 142 | 13000 | 18 | 72 | 742 | 55 | 0 | Severe | Normal |
| 50 | 87 | 153 | Fmale | 37.16519 | 0 | 1 | 0 | 0 | 0 | Y | N | N | N | N | N | N | 110 | 80 | 0 | N | N | Y | N | 0 | Y | 2 | N | N | N | N | 0 | 0 | 0 | 0 | N | N | N | 104 | 1.0 | 170 | 110 | 50 | 16 | 27 | 13.2 | 4.0 | 140 | 9200 | 55 | 39 | 274 | 50 | 0 | Severe | Normal |
Since the features are represented in many different formats, the first step is data cleaning and preparation. Binary features stored as yes/no strings become booleans, and the same happens for Sex and the target variable Cath.
The feature VHD is populated with 4 different kinds of strings; since they imply an ordering, we can easily convert them into the integers 0–3.
The feature BBB also contains multiple strings, but since there is no ordering between its values we have to perform a one-hot encoding, which creates additional variables.
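These conversions can be sketched on a toy excerpt of the raw columns (the values mirror the sample rows shown above; the actual cleaning script may differ in details):

```r
# toy excerpt mimicking the raw CSV columns
df = data.frame(Obesity = c("Y", "N", "Y"),
                VHD = c("N", "mild", "Severe"),
                BBB = c("N", "LBBB", "RBBB"))

# yes/no strings -> 0/1
df$Obesity = ifelse(df$Obesity == "Y", 1, 0)

# VHD: ordered severity levels -> integers 0-3
vhd_levels = c("N" = 0, "mild" = 1, "Moderate" = 2, "Severe" = 3)
df$VHD = unname(vhd_levels[df$VHD])

# BBB: no ordering -> one-hot encode into LBBB / RBBB, drop the original
df$LBBB = as.integer(df$BBB == "LBBB")
df$RBBB = as.integer(df$BBB == "RBBB")
df$BBB = NULL
```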
Moreover, we removed the feature Exertional CP since all its entries have the same value. This is how the same rows as before appear after the preprocessing:
| Age | Weight | Length | Sex | BMI | DM | HTN | Current Smoker | EX Smoker | FH | Obesity | CRF | CVA | Airway disease | Thyroid Disease | CHF | DLP | BP | PR | Edema | Weak Peripheral Pulse | Lung rales | Systolic Murmur | Diastolic Murmur | Typical Chest Pain | Dyspnea | Function Class | Atypical | Nonanginal | LowTH Ang | Q Wave | St Elevation | St Depression | Tinversion | LVH | Poor R Progression | FBS | CR | TG | LDL | HDL | BUN | ESR | HB | K | Na | WBC | Lymph | Neut | PLT | EF TTE | Region RWMA | VHD | Cath | LBBB | RBBB |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 53 | 90 | 175 | 1 | 29.38776 | 0 | 1 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 110 | 80 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 90 | 0.7 | 250 | 155 | 30 | 8 | 7 | 15.6 | 4.7 | 141 | 5700 | 39 | 52 | 261 | 50 | 0 | 0 | Cad | 0 | 0 |
| 67 | 70 | 157 | 0 | 28.39872 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 140 | 80 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 80 | 1.0 | 309 | 121 | 36 | 30 | 26 | 13.9 | 4.7 | 156 | 7700 | 38 | 55 | 165 | 40 | 4 | 0 | Cad | 0 | 0 |
| 54 | 54 | 164 | 1 | 20.07733 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 100 | 100 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 85 | 1.0 | 103 | 70 | 45 | 17 | 10 | 13.5 | 4.7 | 139 | 7400 | 38 | 60 | 230 | 40 | 2 | 1 | Cad | 0 | 0 |
| 66 | 67 | 158 | 0 | 26.83865 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 100 | 80 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 3 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 78 | 1.2 | 63 | 55 | 27 | 30 | 76 | 12.1 | 4.4 | 142 | 13000 | 18 | 72 | 742 | 55 | 0 | 3 | Normal | 0 | 0 |
| 50 | 87 | 153 | 0 | 37.16519 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 110 | 80 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 104 | 1.0 | 170 | 110 | 50 | 16 | 27 | 13.2 | 4.0 | 140 | 9200 | 55 | 39 | 274 | 50 | 0 | 3 | Normal | 0 | 0 |
Exploratory data analysis
An exploratory data analysis can tell us a lot about how the data behave and how the distributions differ between the two target classes, as well as show which features would just introduce noise and are better left out. We can experiment with many different plots; let’s start by plotting the features as histograms:
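A minimal sketch of such a histogram grid with ggplot2 (on synthetic stand-in data, since the real call runs over the full data frame; the feature names here are just placeholders):

```r
library(ggplot2)

# synthetic stand-ins for two continuous features
set.seed(42)
toy = data.frame(value = c(rnorm(100, mean = 58, sd = 10),   # Age-like
                           rbeta(100, 5, 2) * 65),           # EF TTE-like
                 feature = rep(c("Age", "EF TTE"), each = 100))

# one histogram panel per feature, with independent axes
p = ggplot(toy, aes(x = value)) +
  geom_histogram(bins = 20, fill = "steelblue", color = "white") +
  facet_wrap(~ feature, scales = "free")
```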
Despite the low number of entries, we can observe that many continuous variables approximately take the shape of a Normal distribution, while for example EF TTE takes a shape more similar to a Beta distribution. The histograms of the discrete variables don’t give us much information in this form, but we can observe a great imbalance in most of the features.
Another interesting observation that can be made with plots is how the behavior of the distributions changes when conditioned on the target class:
From those plots we can infer some interesting patterns: CAD tends to be correlated with higher age, increased blood pressure and higher fasting blood sugar. Males seem to be more affected than females, as do diabetic people, and some binary variables take a positive response exclusively in the CAD case, in particular those in the electrocardiogram subset of features.
Particularly interesting is the plot of Typical chest pain, where we can see that about 25% of patients with CAD do not experience it, a statistic that fits well with the figures reported in the introduction.
In the data we also find some unusual behaviors: for all features linked to weight, such as Weight, BMI, HDL and Obesity, there is little to no difference between the distributions of the two classes, while common knowledge says that being overweight leads to a major increase in the occurrence of heart disease. We can interpret this observation as a bias of the dataset due to how the data were gathered: since all observations are visitors to a medical facility specialized in cardiovascular diseases, the samples are not representative of the overall population. This reasoning led to the counterintuitive decision of excluding these variables from the study.
Finally, we can analyze the linear correlation between the variables with a correlation plot. Given the high number of variables, a complete correlation matrix would be messy; for this reason we only use the 10 features most correlated with the target:
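The selection step can be sketched as follows (on synthetic data; in the report the roles of `toy` and `y` are played by `data` and the encoded `Cath`, and `k` would be 10):

```r
set.seed(1)
toy = data.frame(a = rnorm(50), b = rnorm(50), c = rnorm(50))
toy$y = as.integer(toy$a + rnorm(50, sd = 0.1) > 0)  # target driven by `a`

# absolute linear correlation of every feature with the target
cors = sapply(toy[setdiff(names(toy), "y")], cor, y = toy$y)
top = names(sort(abs(cors), decreasing = TRUE))[1:2]  # keep the top k features

# corrplot::corrplot(cor(toy[c(top, "y")]))  # draw the reduced matrix
```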
Goals and measures
Since the target variable of this study is binary, the project aims to build a Bayesian logistic regression. As said before, the Bayesian approach is particularly suitable for this problem since it lets us incorporate prior information into the model and is more robust than a frequentist approach when few samples are available.
In terms of model performance, our aim is to obtain the best possible precision and recall scores, but, as for many other medical models, we want to prioritize a high recall: in disease detection it is vital to identify as many true cases as possible, since false positives can be resolved simply with further testing while false negatives may endanger the patient.
To measure performance and prevent the results from being positively distorted by overfitting, we divide the dataset into two subsets: a train set used for fitting and a test set of never-before-seen data used to evaluate the inference capabilities of the model. We use the classic proportion of 75% train and 25% test and set the random seed to ensure the replicability of the split.
# set seed for reproducibility
set.seed(123)
train_indices = sample(seq_len(nrow(data)), size = 0.75 * nrow(data))
# split into train and test sets
train_data = data[train_indices, ]
test_data = data[-train_indices, ]

Modelling the Bayesian problem
The model we want to build is a logistic regression. In this model we use our set of features \(X = (X_1, ..., X_n)\) to infer the value of a target binary categorical variable:
\[ Y = \left\{ \begin{array}{cl} 1 & \text{CAD} \\ 0 & \text{Normal} \end{array} \right. \]
Differently from a linear regression, a logistic regression models as a linear combination of the parameters not a scalar value but the probability of an event, measured in log-odds, defined as \(\phi = \ln\left( \frac{p}{1-p} \right)\), where \(\frac{p}{1-p}\) are the odds and \(p \in [0,1]\) is the probability of the event. Therefore the regression we are modelling is:
\[ \phi = \ln\left( \frac{p}{1-p} \right) = \beta_0 + \beta_1 X_1 + \dots + \beta_n X_n \]
During the inference step we simply invert the formula to obtain the probability:
\[ p = \frac{\exp(\phi)}{1 + \exp(\phi)} = \frac{1}{1 + \exp(-\phi)} \]
And as prediction we just take the class with the highest probability:
\[ \text{pred} = \left\{ \begin{array}{cl} 1 & \text{if } p \ge 0.5 \\ 0 & \text{otherwise} \end{array} \right. \]
Logistic regression is one form of the Generalized Linear Model: it is what we obtain when the target is distributed as a Bernoulli random variable, with the log-odds as its link function.
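In R the link and its inverse are available as `qlogis` and `plogis`, which gives a quick sanity check of the formulas above:

```r
p = 0.8
phi = qlogis(p)                     # log-odds: log(p / (1 - p))
stopifnot(all.equal(phi, log(p / (1 - p))))

p_back = plogis(phi)                # inverse link: 1 / (1 + exp(-phi))
stopifnot(all.equal(p_back, 0.8))

pred = as.integer(p_back >= 0.5)    # decision rule: classify as CAD
```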
Once we have the structure of the model we have to use the training data to infer good values for the \(\beta\) parameters.
Models
1. Naive model (baseline)
Our first model, which we call “naive”, uses every feature in the dataset (apart from the weight-related variables excluded after the exploratory analysis) in the logistic regression. We use this model as a baseline for the others. As priors we use weakly informative priors: normal distributions with zero mean and high variance.
jags_code = "model{
##########
# PRIORS #
##########
# for the intercept we use a normal with mean 0 and precision 0.01
beta0 ~ dnorm(0, 0.01)
# priors for the coefficients
# for those we use a normal with mean 0 and precision 0.4
for (j in 1:n_features){
beta[j] ~ dnorm(0, 0.4)
}
##############
# LIKELIHOOD #
##############
for (i in 1:n_samples){
# calculate logits (log-odds)
logit_p[i] = beta0 + inprod(beta[1:n_features], x[i,])
# convert log-odds into probabilities
p[i] = 1 / (1 + exp(- logit_p[i]))
# get the binary outcome for the target
y[i] ~ dbern(p[i])
}
}"
# exclude the target and the weight-related variables
exclude = c("Cath", "Weight", "Length", "BMI")
features = setdiff(colnames(data), exclude)
# pass parameters to format for JAGS
model_data = list(
x = as.matrix(train_data[, features]),
y = train_data$Cath,
n_samples = nrow(train_data),
n_features = length(features)
)
jags_model1 = jags(model.file=textConnection(jags_code),
data = model_data,
inits = NULL,
n.chains = 5,
n.iter = 15000,
n.burnin = 5000,
                   parameters.to.save = c("beta0", "beta"))

Compiling model graph
Resolving undeclared variables
Allocating nodes
Graph information:
Observed stochastic nodes: 227
Unobserved stochastic nodes: 53
Total graph size: 13680
Initializing model
jags_model1

Inference for Bugs model at "4", fit using jags,
5 chains, each with 15000 iterations (first 5000 discarded), n.thin = 10
n.sims = 5000 iterations saved
mu.vect sd.vect 2.5% 25% 50% 75% 97.5% Rhat n.eff
beta[1] 0.126 0.039 0.056 0.099 0.125 0.151 0.207 1.017 210
beta[2] 0.686 0.840 -0.935 0.111 0.659 1.252 2.353 1.003 1100
beta[3] 1.631 0.897 -0.095 1.029 1.629 2.256 3.343 1.001 5000
beta[4] 1.388 0.817 -0.204 0.828 1.386 1.935 2.981 1.005 780
beta[5] 0.408 0.828 -1.179 -0.160 0.399 0.962 2.096 1.002 2100
beta[6] 0.468 1.353 -2.250 -0.437 0.450 1.364 3.168 1.002 3000
beta[7] 1.924 0.872 0.185 1.345 1.918 2.498 3.639 1.003 1400
beta[8] -0.405 0.722 -1.829 -0.895 -0.387 0.076 0.980 1.001 3000
beta[9] 0.102 1.573 -2.994 -0.949 0.105 1.176 3.191 1.001 4900
beta[10] 0.567 1.299 -1.984 -0.306 0.572 1.469 3.110 1.001 5000
beta[11] 1.036 1.293 -1.491 0.146 1.036 1.906 3.524 1.001 5000
beta[12] -0.001 1.298 -2.526 -0.861 -0.003 0.858 2.550 1.001 3500
beta[13] -0.036 1.584 -3.215 -1.104 -0.037 1.003 3.010 1.001 5000
beta[14] 0.044 0.679 -1.255 -0.420 0.041 0.499 1.378 1.001 5000
beta[15] 0.040 0.026 -0.011 0.023 0.039 0.057 0.091 1.008 850
beta[16] 0.088 0.051 -0.013 0.057 0.087 0.121 0.189 1.020 320
beta[17] -0.326 1.242 -2.700 -1.179 -0.334 0.512 2.075 1.002 2500
beta[18] 0.196 1.529 -2.780 -0.813 0.176 1.227 3.244 1.001 5000
beta[19] 1.305 1.325 -1.281 0.410 1.286 2.205 3.908 1.002 2200
beta[20] 0.293 1.074 -1.752 -0.431 0.276 1.013 2.395 1.002 1600
beta[21] -0.366 1.242 -2.789 -1.211 -0.351 0.470 2.045 1.001 5000
beta[22] 3.444 0.896 1.732 2.844 3.444 4.046 5.209 1.001 5000
beta[23] -2.403 0.784 -3.980 -2.921 -2.379 -1.863 -0.924 1.009 350
beta[24] 0.688 0.404 -0.101 0.409 0.689 0.963 1.472 1.005 690
beta[25] -0.471 0.888 -2.229 -1.072 -0.476 0.145 1.283 1.001 5000
beta[26] -1.377 1.133 -3.579 -2.157 -1.377 -0.613 0.859 1.001 5000
beta[27] -0.011 1.595 -3.122 -1.070 -0.024 1.058 3.100 1.001 5000
beta[28] 0.679 1.380 -2.031 -0.248 0.661 1.599 3.392 1.002 2100
beta[29] 1.243 1.332 -1.336 0.380 1.209 2.114 3.929 1.001 5000
beta[30] 1.381 0.926 -0.368 0.729 1.361 2.012 3.201 1.006 580
beta[31] 1.984 0.846 0.329 1.428 1.982 2.547 3.658 1.001 4100
beta[32] 0.527 1.098 -1.640 -0.208 0.523 1.273 2.689 1.002 2500
beta[33] 0.254 1.489 -2.623 -0.756 0.241 1.275 3.174 1.001 5000
beta[34] 0.003 0.009 -0.015 -0.003 0.003 0.009 0.021 1.001 5000
beta[35] 0.548 1.179 -1.784 -0.237 0.553 1.312 2.894 1.002 2400
beta[36] 0.013 0.006 0.003 0.009 0.013 0.017 0.025 1.003 1300
beta[37] 0.002 0.012 -0.020 -0.006 0.002 0.010 0.026 1.005 640
beta[38] 0.011 0.028 -0.042 -0.007 0.011 0.030 0.068 1.008 390
beta[39] -0.036 0.064 -0.165 -0.079 -0.035 0.007 0.089 1.006 540
beta[40] 0.003 0.029 -0.054 -0.016 0.003 0.023 0.061 1.001 4200
beta[41] -0.387 0.305 -1.035 -0.585 -0.364 -0.171 0.148 1.050 74
beta[42] 1.355 0.827 -0.239 0.792 1.323 1.927 3.006 1.019 220
beta[43] -0.120 0.066 -0.273 -0.155 -0.116 -0.073 -0.012 1.219 21
beta[44] 0.000 0.000 0.000 0.000 0.000 0.000 0.001 1.004 990
beta[45] -0.068 0.086 -0.229 -0.127 -0.066 -0.016 0.120 1.030 130
beta[46] -0.050 0.082 -0.193 -0.110 -0.050 0.005 0.128 1.043 110
beta[47] -0.005 0.007 -0.020 -0.010 -0.005 -0.001 0.008 1.011 290
beta[48] -0.052 0.052 -0.156 -0.087 -0.051 -0.016 0.049 1.010 310
beta[49] 2.753 0.885 1.194 2.134 2.700 3.313 4.641 1.001 3900
beta[50] -0.724 0.605 -1.941 -1.119 -0.709 -0.320 0.459 1.001 3900
beta[51] -0.984 1.211 -3.389 -1.804 -0.954 -0.172 1.370 1.002 1900
beta[52] -0.231 1.198 -2.607 -1.048 -0.218 0.577 2.092 1.002 1800
beta0 0.566 7.456 -15.661 -4.146 0.465 5.425 14.725 1.086 42
deviance 103.363 8.396 88.288 97.459 102.775 108.812 120.908 1.007 460
For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).
DIC info (using the rule, pD = var(deviance)/2)
pD = 35.0 and DIC = 138.3
DIC is an estimate of expected predictive error (lower deviance is better).
# posterior samples of coefficients
beta_samples = as.data.frame(ggs(as.mcmc(jags_model1)))
# calculate the mean for each parameter (beta)
beta_means = beta_samples %>%
group_by(Parameter) %>%
summarize(mean_value = mean(value)) %>%
spread(key = Parameter, value = mean_value)
# prepare the test data with all predictors
model_test_data = as.matrix(test_data[, features])
# calculate predicted probabilities using posterior means
logit_prediction = beta_means$beta0 +
model_test_data %*% as.numeric(beta_means[grep("beta(?!0)", names(beta_means), perl = TRUE)])
# convert log-odds to probabilities
pred_probs = 1 / (1 + exp(- logit_prediction))
# convert probabilities to class predictions
predictions = ifelse(pred_probs >= 0.5, 1, 0)
# calculate metrics
accuracy = sum(predictions == test_data$Cath) / nrow(test_data)
recall = sum(predictions == 1 & test_data$Cath == 1) / sum(test_data$Cath == 1)
precision = sum(predictions == 1 & test_data$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)
# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy: 0.842105263157895"
print(paste("Recall: ", recall))
[1] "Recall: 0.924528301886792"
print(paste("Precision: ", precision))
[1] "Precision: 0.859649122807018"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score: 0.890909090909091"
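Since caret is already loaded, the hand-computed metrics could also be cross-checked with `caret::confusionMatrix` (a sketch on toy predictions, not the model's actual output):

```r
library(caret)

# toy predictions and ground truth (1 = CAD, the positive class)
pred  = factor(c(1, 1, 0, 1, 0), levels = c(0, 1))
truth = factor(c(1, 0, 0, 1, 0), levels = c(0, 1))

cm = confusionMatrix(pred, truth, positive = "1")
cm$byClass["Sensitivity"]  # recall: both true CAD cases are caught -> 1
cm$byClass["Precision"]    # precision: 2 of the 3 positive calls are right
```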
2. Model with feature selection
The second model is very similar to the naive one, but the features observed to introduce noise are removed. What we obtain is a slightly simpler model, which should help convergence.
exclude2 = c("Cath", "Weight", "Length", "BMI", "Atypical", "Nonanginal", "FBS", "Diastolic.Murmur", "Current.Smoker", "EX.Smoker")
features = setdiff(colnames(data), exclude2)
# pass parameters to format for JAGS
model_data = list(
x = as.matrix(train_data[, features]),
y = train_data$Cath,
n_samples = nrow(train_data),
n_features = length(features)
)
jags_model2 = jags(model.file=textConnection(jags_code),
data = model_data,
inits = NULL,
n.chains = 5,
n.iter = 15000,
n.burnin = 5000,
                   parameters.to.save = c("beta0", "beta"))

Compiling model graph
Resolving undeclared variables
Allocating nodes
Graph information:
Observed stochastic nodes: 227
Unobserved stochastic nodes: 47
Total graph size: 12312
Initializing model
jags_model2

Inference for Bugs model at "5", fit using jags,
5 chains, each with 15000 iterations (first 5000 discarded), n.thin = 10
n.sims = 5000 iterations saved
mu.vect sd.vect 2.5% 25% 50% 75% 97.5% Rhat n.eff
beta[1] 0.119 0.039 0.047 0.091 0.117 0.144 0.201 1.007 1100
beta[2] 0.849 0.813 -0.705 0.302 0.827 1.390 2.501 1.011 280
beta[3] 1.865 0.777 0.382 1.341 1.852 2.386 3.409 1.001 3800
beta[4] 1.212 0.774 -0.289 0.686 1.218 1.717 2.726 1.006 540
beta[5] 1.844 0.853 0.204 1.261 1.840 2.424 3.532 1.004 900
beta[6] -0.438 0.715 -1.839 -0.922 -0.435 0.040 0.976 1.001 3500
beta[7] 0.137 1.539 -2.840 -0.913 0.158 1.151 3.187 1.002 2000
beta[8] 0.611 1.267 -1.846 -0.258 0.612 1.442 3.133 1.002 2300
beta[9] 1.024 1.268 -1.424 0.164 1.018 1.876 3.497 1.002 2600
beta[10] -0.071 1.252 -2.472 -0.936 -0.077 0.768 2.432 1.001 5000
beta[11] -0.003 1.556 -3.041 -1.033 0.002 1.022 3.035 1.001 5000
beta[12] 0.038 0.662 -1.255 -0.415 0.040 0.487 1.335 1.004 940
beta[13] 0.038 0.023 -0.006 0.022 0.037 0.053 0.084 1.024 140
beta[14] 0.086 0.051 -0.016 0.052 0.087 0.118 0.190 1.021 230
beta[15] -0.242 1.204 -2.622 -1.052 -0.234 0.572 2.121 1.002 2000
beta[16] 0.167 1.484 -2.686 -0.854 0.173 1.167 3.135 1.002 2100
beta[17] 1.180 1.298 -1.364 0.290 1.167 2.046 3.745 1.001 4600
beta[18] 0.197 1.029 -1.821 -0.503 0.202 0.889 2.213 1.005 670
beta[19] 3.859 0.674 2.583 3.390 3.850 4.304 5.225 1.004 940
beta[20] -2.313 0.766 -3.851 -2.837 -2.313 -1.792 -0.840 1.012 270
beta[21] 0.670 0.384 -0.087 0.412 0.669 0.921 1.443 1.002 2700
beta[22] -0.009 1.576 -3.069 -1.062 -0.022 1.061 3.104 1.002 1900
beta[23] 0.701 1.380 -1.961 -0.250 0.665 1.578 3.547 1.001 3100
beta[24] 1.246 1.325 -1.337 0.348 1.238 2.119 3.902 1.002 1700
beta[25] 1.295 0.922 -0.554 0.676 1.300 1.909 3.080 1.014 230
beta[26] 1.947 0.833 0.363 1.407 1.934 2.489 3.623 1.004 950
beta[27] 0.784 1.102 -1.363 0.031 0.780 1.527 2.908 1.005 660
beta[28] 0.338 1.471 -2.474 -0.671 0.337 1.327 3.276 1.003 1300
beta[29] 0.622 1.130 -1.625 -0.114 0.634 1.367 2.818 1.007 540
beta[30] 0.012 0.005 0.003 0.009 0.012 0.016 0.023 1.009 370
beta[31] 0.001 0.011 -0.021 -0.007 0.001 0.008 0.023 1.011 290
beta[32] 0.013 0.028 -0.042 -0.006 0.012 0.031 0.069 1.013 250
beta[33] -0.029 0.060 -0.147 -0.069 -0.028 0.011 0.087 1.008 400
beta[34] 0.003 0.029 -0.053 -0.016 0.003 0.023 0.063 1.016 200
beta[35] -0.373 0.315 -1.032 -0.588 -0.356 -0.151 0.223 1.085 44
beta[36] 1.496 0.816 -0.108 0.924 1.498 2.088 3.063 1.045 72
beta[37] -0.100 0.091 -0.311 -0.159 -0.097 -0.032 0.059 1.251 19
beta[38] 0.000 0.000 0.000 0.000 0.000 0.000 0.001 1.011 320
beta[39] -0.038 0.092 -0.223 -0.103 -0.039 0.028 0.134 1.015 210
beta[40] -0.020 0.090 -0.197 -0.084 -0.017 0.048 0.143 1.023 150
beta[41] -0.007 0.007 -0.020 -0.011 -0.006 -0.002 0.006 1.022 150
beta[42] -0.034 0.055 -0.145 -0.071 -0.033 0.005 0.072 1.030 120
beta[43] 2.709 0.863 1.188 2.111 2.663 3.240 4.545 1.004 900
beta[44] -0.613 0.590 -1.777 -1.000 -0.603 -0.213 0.525 1.003 1400
beta[45] -0.963 1.181 -3.328 -1.747 -0.947 -0.163 1.313 1.002 2500
beta[46] -0.163 1.152 -2.458 -0.923 -0.155 0.597 2.071 1.002 2500
beta0 -5.867 7.724 -21.395 -11.500 -5.116 -0.111 8.622 1.139 27
deviance 101.437 8.071 87.288 95.677 100.966 106.641 118.863 1.014 220
For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).
DIC info (using the rule, pD = var(deviance)/2)
pD = 32.0 and DIC = 133.4
DIC is an estimate of expected predictive error (lower deviance is better).
# posterior samples of coefficients
beta_samples = as.data.frame(ggs(as.mcmc(jags_model2)))
# calculate the mean for each parameter (beta)
beta_means = beta_samples %>%
group_by(Parameter) %>%
summarize(mean_value = mean(value)) %>%
spread(key = Parameter, value = mean_value)
# prepare the test data with all predictors
model_test_data = as.matrix(test_data[, features])
# calculate predicted probabilities using posterior means
logit_prediction = beta_means$beta0 +
model_test_data %*% as.numeric(beta_means[grep("beta(?!0)", names(beta_means), perl = TRUE)])
# convert log-odds to probabilities
pred_probs = 1 / (1 + exp(- logit_prediction))
# convert probabilities to class predictions
predictions = ifelse(pred_probs >= 0.5, 1, 0)
# calculate metrics
accuracy = sum(predictions == test_data$Cath) / nrow(test_data)
recall = sum(predictions == 1 & test_data$Cath == 1) / sum(test_data$Cath == 1)
precision = sum(predictions == 1 & test_data$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)
# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy: 0.842105263157895"
print(paste("Recall: ", recall))
[1] "Recall: 0.924528301886792"
print(paste("Precision: ", precision))
[1] "Precision: 0.859649122807018"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score: 0.890909090909091"
3. Model with feature engineering
This third model takes inspiration from the original paper by Dr. Alizadeh Sani on this dataset. Some variables are discretized into the levels “Low”, “Normal” and “High”; for some features the thresholds depend on the sex of the patient.
| Feature | Low | Normal | High |
|---|---|---|---|
| Cr | Cr < 0.7 | 0.7 ≤ Cr ≤ 1.5 | Cr > 1.5 |
| FBS | FBS < 70 | 70 ≤ FBS ≤ 105 | FBS > 105 |
| LDL | LDL < 130 | LDL > 130 | |
| HDL | HDL < 35 | HDL ≥ 35 | |
| BUN | BUN < 7 | 7 ≤ BUN ≤ 20 | BUN > 20 |
| ESR | If male and ESR ≤ age/2, if female and ESR ≤ age/2 + 5 | If male and ESR > age/2, if female and ESR > age/2 + 5 | |
| HB | If male and HB < 14, if female and HB < 12.5 | If male and 14 ≤ HB ≤ 17, if female and 12.5 ≤ HB ≤ 15 | If male and HB > 17, if female and HB > 15 |
| K | K < 3.8 | 3.8 ≤ K ≤ 5.6 | K > 5.6 |
| Na | Na < 136 | 136 ≤ Na ≤ 146 | Na > 146 |
| WBC | WBC < 4000 | 4000 ≤ WBC ≤ 11,000 | WBC > 11,000 |
| PLT | PLT < 150 | 150 ≤ PLT ≤ 450 | PLT > 450 |
| EF | EF ≤ 50 | EF > 50 | |
| Region with RWMA | Region with RWMA = 0 | Region with RWMA ≠ 0 | |
| Age | If male and age ≤ 45, if female and age ≤ 55 | If male and age > 45, if female and age > 55 | |
| BP | BP < 90 | 90 ≤ BP ≤ 140 | BP > 140 |
| PulseRate | PulseRate < 60 | 60 ≤ PulseRate ≤ 100 | PulseRate > 100 |
| TG | TG < 200 | TG ≥ 200 | |
# discretize features according to the table above
data$CR = as.numeric(cut(data$CR, breaks=c(-Inf, 0.7, 1.5, Inf), labels=c(0, 1, 2)))
data$FBS = as.numeric(cut(data$FBS, breaks=c(-Inf, 70, 105, Inf), labels=c(0, 1, 2)))
data$LDL = as.numeric(cut(data$LDL, breaks=c(-Inf, 130, Inf), labels=c(1, 2)))
data$HDL = as.numeric(cut(data$HDL, breaks=c(-Inf, 35, Inf), labels=c(0, 1)))
data$BUN = as.numeric(cut(data$BUN, breaks=c(-Inf, 7, 20, Inf), labels=c(0, 1, 2)))
data$K = as.numeric(cut(data$K, breaks=c(-Inf, 3.8, 5.6, Inf), labels=c(0, 1, 2)))
data$Na = as.numeric(cut(data$Na, breaks=c(-Inf, 136, 146, Inf), labels=c(0, 1, 2)))
data$WBC = as.numeric(cut(data$WBC, breaks=c(-Inf, 4000, 11000, Inf), labels=c(0, 1, 2)))
data$PLT = as.numeric(cut(data$PLT, breaks=c(-Inf, 150, 450, Inf), labels=c(0, 1, 2)))
data$EF.TTE = as.numeric(cut(data$EF.TTE, breaks=c(-Inf, 50, Inf), labels=c(0, 1)))
data$BP = as.numeric(cut(data$BP, breaks=c(-Inf, 90, 140, Inf), labels=c(0, 1, 2)))
data$PR = as.numeric(cut(data$PR, breaks=c(-Inf, 60, 100, Inf), labels=c(0, 1, 2)))
data$TG = as.numeric(cut(data$TG, breaks=c(-Inf, 200, Inf), labels=c(1, 2)))
data$Function.Class = as.numeric(cut(data$Function.Class, breaks=c(-Inf, 1.5, Inf), labels=c(1, 2)))
data$Region.RWMA = as.numeric(ifelse(data$Region.RWMA == 0, 1, 2))
data$ESR = as.numeric(with(data, ifelse((Sex == 1 & ESR <= Age/2) | (Sex == 0 & ESR <= Age/2 + 5), 1, 2)))
data$HB = as.numeric(with(data, ifelse((Sex == 1 & HB < 14) | (Sex == 0 & HB < 12.5), 0, ifelse((Sex == 1 & HB <= 17) | (Sex == 0 & HB <= 15), 1, 2))))
data$Age = as.numeric(with(data, ifelse((Sex == 1 & Age > 45) | (Sex == 0 & Age > 55), 2, 1)))
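A quick sanity check on the bucketing (a toy example, not part of the original pipeline). Note two subtleties of the code above: cut() is right-closed by default, so a value exactly at a break falls into the lower bucket, and as.numeric() applied to the resulting factor returns the level codes 1, 2, 3 rather than the labels 0, 1, 2; since this shift is the same for every feature, it is absorbed by the intercept of the logistic model.

```r
# cut() with default right = TRUE yields intervals (-Inf, 0.7], (0.7, 1.5], (1.5, Inf)
cr_example <- c(0.5, 0.7, 1.0, 1.5, 2.0)
codes <- as.numeric(cut(cr_example, breaks = c(-Inf, 0.7, 1.5, Inf), labels = c(0, 1, 2)))
codes  # 1 1 2 2 3 -- factor level codes, not the labels; boundary values fall in the lower bucket
```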
# update train and test datasets
train_data_discrete = data[train_indices, ]
test_data_discrete = data[-train_indices, ]
features = setdiff(colnames(data), exclude)
# pass parameters to format for JAGS
model_data = list(
x = as.matrix(train_data_discrete[, features]),
y = train_data_discrete$Cath,
n_samples = nrow(train_data_discrete),
n_features = length(features)
)
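The model string jags_code used below is defined earlier in the report. For readability, a Bayesian logistic-regression specification of roughly this shape matches the inputs in model_data and the 53 unobserved nodes reported by JAGS (52 coefficients plus an intercept); this is only a sketch, and the exact priors used earlier may differ.

```r
# sketch of a JAGS logistic-regression model consistent with model_data;
# the actual jags_code and its priors are defined earlier in the report
jags_code_sketch <- "
model {
  for (i in 1:n_samples) {
    y[i] ~ dbern(p[i])                        # Bernoulli likelihood
    logit(p[i]) <- beta0 + inprod(x[i, ], beta)
  }
  beta0 ~ dnorm(0, 0.01)                      # weakly informative priors
  for (j in 1:n_features) {
    beta[j] ~ dnorm(0, 0.01)
  }
}"
```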
jags_model3 = jags(model.file=textConnection(jags_code),
data = model_data,
inits = NULL,
n.chains = 5,
n.iter = 15000,
n.burnin = 5000,
parameters.to.save = c("beta0", "beta"))
Compiling model graph
Resolving undeclared variables
Allocating nodes
Graph information:
Observed stochastic nodes: 227
Unobserved stochastic nodes: 53
Total graph size: 13680
Initializing model
jags_model3
Inference for Bugs model at "6", fit using jags,
5 chains, each with 15000 iterations (first 5000 discarded), n.thin = 10
n.sims = 5000 iterations saved
mu.vect sd.vect 2.5% 25% 50% 75% 97.5% Rhat n.eff
beta[1] 2.252 0.726 0.881 1.760 2.247 2.733 3.726 1.005 660
beta[2] 0.010 0.668 -1.295 -0.444 0.010 0.458 1.310 1.001 5000
beta[3] 2.192 0.786 0.656 1.654 2.195 2.733 3.727 1.002 2000
beta[4] 1.949 0.650 0.682 1.517 1.934 2.379 3.242 1.002 2600
beta[5] -0.034 0.769 -1.557 -0.555 -0.031 0.479 1.463 1.001 5000
beta[6] 1.010 1.241 -1.365 0.169 1.007 1.838 3.510 1.001 4200
beta[7] 1.493 0.832 -0.073 0.931 1.491 2.030 3.154 1.001 3600
beta[8] -0.990 0.664 -2.310 -1.431 -0.985 -0.557 0.338 1.001 5000
beta[9] 0.104 1.519 -2.720 -0.925 0.059 1.128 3.042 1.001 3100
beta[10] -0.195 1.268 -2.640 -1.036 -0.221 0.679 2.323 1.001 5000
beta[11] 1.172 1.186 -1.104 0.366 1.203 1.983 3.482 1.001 3300
beta[12] -0.539 1.208 -2.864 -1.357 -0.552 0.263 1.884 1.001 3900
beta[13] 0.233 1.499 -2.677 -0.779 0.222 1.249 3.144 1.001 5000
beta[14] -0.466 0.605 -1.632 -0.867 -0.465 -0.071 0.729 1.003 1100
beta[15] -0.526 0.849 -2.220 -1.108 -0.508 0.071 1.071 1.004 1000
beta[16] 1.600 1.083 -0.376 0.855 1.543 2.271 4.008 1.004 1600
beta[17] 0.346 1.099 -1.748 -0.409 0.345 1.075 2.551 1.001 5000
beta[18] 0.569 1.419 -2.143 -0.398 0.538 1.489 3.422 1.003 1500
beta[19] 1.165 1.194 -1.129 0.346 1.172 1.986 3.476 1.001 5000
beta[20] 1.250 0.894 -0.474 0.656 1.255 1.826 3.042 1.001 5000
beta[21] -0.611 1.183 -2.991 -1.407 -0.573 0.187 1.589 1.001 4500
beta[22] 3.495 0.833 1.921 2.923 3.473 4.041 5.169 1.003 1400
beta[23] -1.562 0.659 -2.867 -2.010 -1.565 -1.108 -0.281 1.001 5000
beta[24] 0.818 0.714 -0.555 0.332 0.806 1.295 2.227 1.003 1200
beta[25] -0.394 0.796 -1.962 -0.935 -0.397 0.155 1.157 1.001 3100
beta[26] -1.246 0.998 -3.205 -1.917 -1.245 -0.593 0.728 1.002 2900
beta[27] 0.022 1.549 -3.015 -1.026 0.034 1.074 3.069 1.001 5000
beta[28] 1.263 1.301 -1.186 0.385 1.241 2.090 3.920 1.001 3600
beta[29] 1.123 1.289 -1.356 0.232 1.111 1.985 3.719 1.001 3200
beta[30] 1.289 0.843 -0.337 0.719 1.287 1.859 2.959 1.004 910
beta[31] 1.790 0.759 0.322 1.285 1.763 2.292 3.280 1.002 2900
beta[32] -0.015 0.987 -1.836 -0.684 -0.035 0.625 1.957 1.001 3300
beta[33] 0.357 1.444 -2.416 -0.604 0.320 1.304 3.282 1.002 1800
beta[34] 0.274 0.664 -1.001 -0.178 0.270 0.716 1.610 1.007 490
beta[35] -0.963 0.872 -2.686 -1.567 -0.966 -0.352 0.709 1.005 740
beta[36] 0.707 0.759 -0.746 0.189 0.713 1.212 2.183 1.004 790
beta[37] 0.697 0.729 -0.710 0.213 0.689 1.187 2.154 1.003 1400
beta[38] 0.103 0.652 -1.193 -0.328 0.102 0.548 1.364 1.003 1500
beta[39] 0.116 0.664 -1.169 -0.333 0.115 0.553 1.430 1.004 1100
beta[40] -0.653 0.764 -2.141 -1.150 -0.659 -0.139 0.861 1.001 3300
beta[41] -0.103 0.575 -1.239 -0.482 -0.110 0.281 1.023 1.001 5000
beta[42] -0.176 0.720 -1.561 -0.679 -0.177 0.311 1.252 1.001 5000
beta[43] -0.909 0.941 -2.778 -1.537 -0.913 -0.284 0.997 1.023 140
beta[44] 1.161 1.102 -0.856 0.404 1.132 1.884 3.372 1.019 170
beta[45] -0.043 0.074 -0.190 -0.093 -0.041 0.004 0.108 1.028 110
beta[46] -0.010 0.073 -0.153 -0.060 -0.007 0.035 0.140 1.031 100
beta[47] 0.130 1.069 -2.013 -0.579 0.108 0.861 2.204 1.006 580
beta[48] -1.454 0.653 -2.749 -1.901 -1.441 -1.004 -0.189 1.002 1500
beta[49] 3.268 0.936 1.489 2.629 3.243 3.890 5.209 1.004 820
beta[50] -0.927 0.531 -1.994 -1.283 -0.914 -0.556 0.074 1.002 2200
beta[51] -0.151 1.109 -2.348 -0.906 -0.139 0.595 2.022 1.002 2200
beta[52] 0.071 1.168 -2.212 -0.692 0.080 0.855 2.372 1.001 5000
beta0 -7.353 6.605 -21.191 -11.398 -6.816 -2.952 4.771 1.040 100
deviance 100.803 8.419 85.863 94.886 100.395 106.189 118.479 1.001 5000
For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).
DIC info (using the rule, pD = var(deviance)/2)
pD = 35.5 and DIC = 136.3
DIC is an estimate of expected predictive error (lower deviance is better).
# posterior samples of coefficients
beta_samples = as.data.frame(ggs(as.mcmc(jags_model3)))
# calculate the mean for each parameter (beta)
beta_means = beta_samples %>%
group_by(Parameter) %>%
summarize(mean_value = mean(value)) %>%
spread(key = Parameter, value = mean_value)
# prepare the test data with all predictors
model_test_data = as.matrix(test_data_discrete[, features])
# calculate predicted probabilities using posterior means
logit_prediction = beta_means$beta0 +
model_test_data %*% as.numeric(beta_means[grep("beta(?!0)", names(beta_means), perl = TRUE)])
# convert log-odds to probabilities
pred_probs = 1 / (1 + exp(- logit_prediction))
# convert probabilities to class predictions
predictions = ifelse(pred_probs >= 0.5, 1, 0)
# calculate metrics
accuracy = sum(predictions == test_data_discrete$Cath) / nrow(test_data_discrete)
recall = sum(predictions == 1 & test_data_discrete$Cath == 1) / sum(test_data_discrete$Cath == 1)
precision = sum(predictions == 1 & test_data_discrete$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)
# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy: 0.868421052631579"
print(paste("Recall: ", recall))
[1] "Recall: 0.962264150943396"
print(paste("Precision: ", precision))
[1] "Precision: 0.864406779661017"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score: 0.910714285714286"
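The four metric computations above can be packaged into a small helper; classification_metrics is a hypothetical convenience function, equivalent to the inline code, shown here on a toy example.

```r
# hypothetical helper equivalent to the inline metric computations above
classification_metrics <- function(pred, truth) {
  tp <- sum(pred == 1 & truth == 1)   # true positives
  accuracy  <- mean(pred == truth)
  recall    <- tp / sum(truth == 1)
  precision <- tp / sum(pred == 1)
  f1 <- 2 * (precision * recall) / (precision + recall)
  c(accuracy = accuracy, recall = recall, precision = precision, f1 = f1)
}
# toy example: 3 predicted positives, 3 actual positives, 2 true positives
m <- classification_metrics(pred = c(1, 1, 0, 1, 0), truth = c(1, 0, 0, 1, 1))
m  # accuracy 0.6, recall 2/3, precision 2/3, f1 2/3
```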
4. Model with “extreme” feature selection and feature engineering
The final model we propose is much more extreme than the first three: we keep the discretization introduced above but perform a very aggressive feature selection, keeping only those features whose Pearson correlation with the target is reasonably high.
# define threshold
threshold = 0.2
# get correlations
correlations = sapply(names(data)[names(data) != "Cath"], function(x){
cor(data[[x]], data[["Cath"]])
})
correlations = data.frame(Feature = names(correlations), Correlation = correlations)
# select features with correlation above the threshold
features = (correlations %>% filter(abs(Correlation) > threshold))$Feature
features
 [1] "Age" "DM" "HTN"
[4] "Typical.Chest.Pain" "Atypical" "Nonanginal"
[7] "Tinversion" "FBS" "EF.TTE"
[10] "Region.RWMA"
With the threshold set to 0.2 we keep only 10 features. Compared to the previous models we can expect somewhat lower performance but much better convergence, given the lower complexity of the model.
# pass parameters to format for JAGS
model_data = list(
x = as.matrix(train_data_discrete[, features]),
y = train_data_discrete$Cath,
n_samples = nrow(train_data_discrete),
n_features = length(features)
)
jags_model4 = jags(model.file=textConnection(jags_code),
data = model_data,
inits = NULL,
n.chains = 5,
n.iter = 15000,
n.burnin = 5000,
parameters.to.save = c("beta0", "beta"))
Compiling model graph
Resolving undeclared variables
Allocating nodes
Graph information:
Observed stochastic nodes: 227
Unobserved stochastic nodes: 11
Total graph size: 3438
Initializing model
jags_model4
Inference for Bugs model at "7", fit using jags,
5 chains, each with 15000 iterations (first 5000 discarded), n.thin = 10
n.sims = 5000 iterations saved
mu.vect sd.vect 2.5% 25% 50% 75% 97.5% Rhat n.eff
beta[1] 1.405 0.510 0.424 1.058 1.397 1.737 2.436 1.005 620
beta[2] 1.365 0.601 0.188 0.950 1.368 1.771 2.558 1.004 940
beta[3] 1.275 0.454 0.422 0.962 1.268 1.568 2.180 1.001 5000
beta[4] 2.927 0.636 1.703 2.506 2.918 3.333 4.203 1.001 5000
beta[5] 0.026 0.604 -1.159 -0.387 0.023 0.432 1.216 1.001 4000
beta[6] -1.116 0.814 -2.778 -1.646 -1.112 -0.560 0.440 1.001 4300
beta[7] 1.446 0.571 0.358 1.058 1.434 1.822 2.599 1.001 3400
beta[8] 0.475 0.512 -0.542 0.124 0.481 0.821 1.482 1.004 880
beta[9] -0.897 0.447 -1.766 -1.196 -0.892 -0.596 -0.020 1.002 2400
beta[10] 3.103 0.775 1.650 2.563 3.082 3.606 4.683 1.002 1700
beta0 -7.265 1.933 -11.280 -8.477 -7.205 -5.956 -3.599 1.003 1300
deviance 120.630 4.721 113.145 117.277 119.947 123.426 131.114 1.002 1700
For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).
DIC info (using the rule, pD = var(deviance)/2)
pD = 11.1 and DIC = 131.8
DIC is an estimate of expected predictive error (lower deviance is better).
# posterior samples of coefficients
beta_samples = as.data.frame(ggs(as.mcmc(jags_model4)))
# calculate the mean for each parameter (beta)
beta_means = beta_samples %>%
group_by(Parameter) %>%
summarize(mean_value = mean(value)) %>%
spread(key = Parameter, value = mean_value)
# prepare the test data with all predictors
model_test_data = as.matrix(test_data_discrete[, features])
# calculate predicted probabilities using posterior means
logit_prediction = beta_means$beta0 +
model_test_data %*% as.numeric(beta_means[grep("beta(?!0)", names(beta_means), perl = TRUE)])
# convert log-odds to probabilities
pred_probs = 1 / (1 + exp(- logit_prediction))
# convert probabilities to class predictions
predictions = ifelse(pred_probs >= 0.5, 1, 0)
# calculate metrics
accuracy = sum(predictions == test_data_discrete$Cath) / nrow(test_data_discrete)
recall = sum(predictions == 1 & test_data_discrete$Cath == 1) / sum(test_data_discrete$Cath == 1)
precision = sum(predictions == 1 & test_data_discrete$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)
# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy: 0.868421052631579"
print(paste("Recall: ", recall))
[1] "Recall: 0.943396226415094"
print(paste("Precision: ", precision))
[1] "Precision: 0.87719298245614"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score: 0.909090909090909"
Model comparison
One of the main criteria for model selection among Bayesian models fitted with Markov chain Monte Carlo is the Deviance Information Criterion (DIC). Under the assumption of an approximately multivariate normal posterior distribution of the parameters, this criterion computes a score (the lower the better) that rewards goodness of fit and penalizes model complexity.
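The rule pD = var(deviance)/2 that R2jags uses can be reproduced directly from the posterior deviance samples; a minimal sketch:

```r
# DIC from posterior deviance samples, using the R2jags rule
# pD = var(deviance) / 2 and DIC = mean(deviance) + pD
dic_from_deviance <- function(dev) {
  pD <- var(dev) / 2
  c(pD = pD, DIC = mean(dev) + pD)
}
dic_from_deviance(c(100, 102, 104))  # pD = 2, DIC = 104
```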
Comparing the four models, the first two achieve the same performance metrics, with the second having a lower DIC. Both, however, show convergence problems, as some Rhat values are \(>1.1\) despite the high number of iterations.
The model that uses feature discretization performs better on all metrics and reaches a lower DIC than the first model, while using the same number of features.
The last model, despite the drastic reduction in the number of features, outperforms the first two on all metrics and achieves about the same F1-score as the third, with a slightly lower recall and slightly higher precision, and it has the lowest DIC of all models.
In light of these observations the best models among the four proposed are the third and the fourth, and in the next part of the report we focus on these two to study convergence.
MCMC convergence diagnostics
What we are doing with JAGS is estimating the posterior distribution of the parameters from their priors and the data using Markov chain Monte Carlo (MCMC). If the chains have not converged the sampling will be biased, leading to inaccurate predictions. Therefore, when using JAGS it is critically important to check the convergence of the chains.
There are many diagnostics for MCMC convergence. To get an overview we explored the R library ggmcmc, which contains various tools for assessing and diagnosing convergence, and decided to use: the Gelman-Rubin statistic (R-hat) and the effective sample size (ESS), both also included in the summary printed at the end of the JAGS run; Geweke's diagnostic; autocorrelation; trace plots; and density plots.
Geweke’s diagnostic
This diagnostic for the convergence of MCMC was proposed by Geweke in 1992. It is a hypothesis test whose null hypothesis is that the Markov chain is in its stationary distribution. It is based on a test for equality of the means of the first and last parts of the chain. The reported value is the Z-score: the difference between the two sample means divided by its estimated standard error.
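A minimal sketch of the statistic, using naive iid standard errors (packages such as coda estimate the variances via spectral densities instead):

```r
# Geweke Z-score comparing the first 10% and last 50% of a chain;
# naive standard errors, unlike coda::geweke.diag's spectral estimates
geweke_z <- function(x, frac1 = 0.1, frac2 = 0.5) {
  n <- length(x)
  a <- x[seq_len(floor(frac1 * n))]          # first part of the chain
  b <- x[seq(n - floor(frac2 * n) + 1, n)]   # last part of the chain
  (mean(a) - mean(b)) / sqrt(var(a) / length(a) + var(b) / length(b))
}
```

For a chain already in its stationary distribution the Z-score behaves like a standard normal, so values far outside \([-2, 2]\) flag non-convergence.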
Plot for model 3:
Plot for model 4:
Gelman-Rubin statistic (R-hat)
The Gelman-Rubin statistic, also known as the Potential Scale Reduction Factor or simply R-hat, was proposed by Gelman and Rubin in 1992 to assess the convergence of Markov chains. Its value is:
\[ \hat{R} = \frac{\frac{n-1}{n}W + \frac{1}{n}B}{W} \]
where \(n\) is the length of each chain, \(B\) is the between-chain variance (computed from the chain means) and \(W\) is the mean within-chain variance. The usual threshold to assess convergence is \(\hat{R} < 1.1\).
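The formula can be coded directly; a sketch (library implementations such as coda::gelman.diag include additional corrections, so values may differ slightly):

```r
# Gelman-Rubin statistic following the formula above;
# chains is a matrix with one chain per column
rhat <- function(chains) {
  n <- nrow(chains)                  # length of each chain
  W <- mean(apply(chains, 2, var))   # mean within-chain variance
  B <- n * var(colMeans(chains))     # between-chain variance
  ((n - 1) / n * W + B / n) / W
}
```

For well-mixed chains sampling the same distribution the value is close to 1.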
Plot for model 3:
Plot for model 4:
As noted before, the first two models have some Rhat values above the threshold, leading us to conclude that they have convergence issues. Both the third and the fourth model have all values below the threshold, and the fourth in particular has all values well below it.
Effective sample size (ESS) and autocorrelation
Within a chain, samples tend to be autocorrelated. The effective sample size is an estimate of the number of independent samples that would carry the same amount of information as the autocorrelated chain.
The plots show the lag-k autocorrelation, i.e. the correlation between a sample and the sample k steps before it. This value should shrink toward zero as k increases, indicating that sufficiently distant samples can be considered independent.
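The link between the two quantities can be sketched as ESS = N / (1 + 2·Σₖ ρₖ), summing the lag-k autocorrelations up to the first negative one; this is a simplified stand-in for what packages like coda compute via spectral densities.

```r
# effective sample size from lag-k autocorrelations:
# ESS = N / (1 + 2 * sum of positive-lag autocorrelations),
# truncating the sum at the first negative autocorrelation
ess <- function(x, max_lag = 200) {
  rho <- acf(x, lag.max = max_lag, plot = FALSE)$acf[-1]  # drop lag 0
  first_neg <- which(rho < 0)[1]
  if (!is.na(first_neg)) rho <- rho[seq_len(first_neg - 1)]
  length(x) / (1 + 2 * sum(rho))
}
```

For an iid sequence the ESS is close to the chain length, while strong autocorrelation (as seen for beta[45], beta[46] and beta0 in model 3) shrinks it drastically.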
The ESS is already reported in the JAGS execution while we insert below the autocorrelation plots:
Plot for model 3:
Plot for model 4:
We can see from both the autocorrelation plots and the low effective sample sizes that for the third model the parameters beta[45] and beta[46] and the intercept beta0 present some issues, with their autocorrelation curves still significantly different from zero even at large k, although they do appear to keep decreasing.
The fourth model, meanwhile, seems to present no issues at all for this diagnostic.
Trace plots
Trace plots show the behavior of each chain for each parameter over the iterations. In a trace plot we want to avoid flat stretches where a chain stays in the same state for too long, as well as too many consecutive steps in the same direction.
Plot for model 3:
Plot for model 4:
This diagnostic confirms what we have already seen: for the third model some parameters seem not to have converged yet; critical behavior appears in the intercept beta0 but also in some coefficients, more than the previous tools had flagged.
Density plots
These are density plots of the posterior parameter distributions. Since we have multiple chains, multiple curves are superimposed. Similar distributions across chains for the same parameter are a good sign of convergence.
Plot for model 3:
Plot for model 4:
This diagnostic also confirms the observations above: for the third model the intercept and some coefficients have distributions that differ greatly between chains, signaling convergence issues, while in the fourth model we observe much more uniform behavior.
Comparative analysis with frequentist inference
As suggested by the Final Project Guidelines, we compare our models with what a frequentist model can achieve. This provides empirical evidence about the effectiveness of the Bayesian approach.
features = colnames(data)[colnames(data) != "Cath"]
model_data = train_data[, features]
# fit the logistic regression
logistic_model = glm(Cath ~ Age + Weight + Length + Sex + BMI + DM + HTN +
Current.Smoker + EX.Smoker + FH + Obesity + CRF +
CVA + Airway.disease + Thyroid.Disease + CHF + DLP +
BP + PR + Edema + Weak.Peripheral.Pulse + Lung.rales +
Systolic.Murmur + Diastolic.Murmur + Typical.Chest.Pain +
Dyspnea + Function.Class + Atypical + Nonanginal +
LowTH.Ang + Q.Wave + St.Elevation + St.Depression +
Tinversion + LVH + Poor.R.Progression + FBS + CR +
TG + LDL + HDL + BUN + ESR + HB + K + Na + WBC +
Lymph + Neut + PLT + EF.TTE + Region.RWMA + VHD +
LBBB + RBBB,
data = train_data)
summary(logistic_model)
Call:
glm(formula = Cath ~ Age + Weight + Length + Sex + BMI + DM +
HTN + Current.Smoker + EX.Smoker + FH + Obesity + CRF + CVA +
Airway.disease + Thyroid.Disease + CHF + DLP + BP + PR +
Edema + Weak.Peripheral.Pulse + Lung.rales + Systolic.Murmur +
Diastolic.Murmur + Typical.Chest.Pain + Dyspnea + Function.Class +
Atypical + Nonanginal + LowTH.Ang + Q.Wave + St.Elevation +
St.Depression + Tinversion + LVH + Poor.R.Progression + FBS +
CR + TG + LDL + HDL + BUN + ESR + HB + K + Na + WBC + Lymph +
Neut + PLT + EF.TTE + Region.RWMA + VHD + LBBB + RBBB, data = train_data)
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 3.746e+00 3.109e+00 1.205 0.23004
Age 1.101e-02 2.727e-03 4.038 8.12e-05 ***
Weight 2.405e-02 1.909e-02 1.260 0.20926
Length -2.477e-02 1.772e-02 -1.398 0.16387
Sex 1.048e-01 8.037e-02 1.304 0.19404
BMI -7.557e-02 5.038e-02 -1.500 0.13550
DM 2.073e-01 7.236e-02 2.865 0.00469 **
HTN 9.518e-02 6.475e-02 1.470 0.14344
Current.Smoker 9.230e-02 6.804e-02 1.357 0.17670
EX.Smoker 5.140e-02 1.555e-01 0.331 0.74135
FH 1.576e-01 7.188e-02 2.193 0.02969 *
Obesity 3.448e-02 7.600e-02 0.454 0.65062
CRF -9.709e-02 1.743e-01 -0.557 0.57829
CVA 2.075e-02 1.677e-01 0.124 0.90164
Airway.disease 1.343e-01 1.203e-01 1.116 0.26594
Thyroid.Disease 1.038e-03 1.569e-01 0.007 0.99473
CHF 2.907e-01 4.098e-01 0.709 0.47909
DLP -2.869e-02 4.985e-02 -0.575 0.56572
BP 8.179e-04 1.576e-03 0.519 0.60441
PR 6.578e-03 3.182e-03 2.067 0.04023 *
Edema -7.337e-02 1.218e-01 -0.602 0.54766
Weak.Peripheral.Pulse -9.145e-03 1.991e-01 -0.046 0.96341
Lung.rales 3.139e-01 1.519e-01 2.066 0.04030 *
Systolic.Murmur 2.358e-02 8.729e-02 0.270 0.78736
Diastolic.Murmur -1.349e-01 1.521e-01 -0.887 0.37648
Typical.Chest.Pain 2.595e-01 8.304e-02 3.125 0.00209 **
Dyspnea -6.490e-02 5.758e-02 -1.127 0.26130
Function.Class 4.563e-02 2.649e-02 1.723 0.08678 .
Atypical -9.878e-02 8.901e-02 -1.110 0.26863
Nonanginal -1.839e-01 1.249e-01 -1.473 0.14254
LowTH.Ang -1.958e-01 3.737e-01 -0.524 0.60100
Q.Wave -9.562e-02 1.448e-01 -0.660 0.50994
St.Elevation 1.833e-01 1.478e-01 1.240 0.21652
St.Depression -3.279e-03 6.757e-02 -0.049 0.96135
Tinversion 1.028e-01 6.233e-02 1.649 0.10102
LVH 6.016e-02 1.164e-01 0.517 0.60610
Poor.R.Progression 6.398e-02 1.593e-01 0.402 0.68838
FBS 2.994e-05 6.210e-04 0.048 0.96160
CR 2.895e-02 1.131e-01 0.256 0.79830
TG 5.121e-04 2.892e-04 1.771 0.07840 .
LDL -4.419e-04 7.523e-04 -0.587 0.55774
HDL -9.285e-04 2.271e-03 -0.409 0.68317
BUN -7.088e-03 4.086e-03 -1.735 0.08458 .
ESR 1.342e-03 1.845e-03 0.728 0.46781
HB -2.543e-03 1.850e-02 -0.137 0.89080
K 1.047e-01 5.466e-02 1.915 0.05716 .
Na -5.648e-03 6.825e-03 -0.828 0.40905
WBC -8.267e-06 1.321e-05 -0.626 0.53227
Lymph 8.307e-04 6.293e-03 0.132 0.89513
Neut 2.149e-03 6.157e-03 0.349 0.72748
PLT 1.966e-05 4.221e-04 0.047 0.96290
EF.TTE 1.799e-03 3.594e-03 0.501 0.61732
Region.RWMA 4.716e-02 2.526e-02 1.867 0.06361 .
VHD -7.498e-02 4.167e-02 -1.799 0.07372 .
LBBB -1.318e-01 1.249e-01 -1.055 0.29307
RBBB 2.637e-02 1.256e-01 0.210 0.83390
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
(Dispersion parameter for gaussian family taken to be 0.1009441)
Null deviance: 45.956 on 226 degrees of freedom
Residual deviance: 17.261 on 171 degrees of freedom
AIC: 173.34
Number of Fisher Scoring iterations: 2
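Note that the summary above reports a dispersion parameter "for gaussian family": glm() defaults to family = gaussian, so the fit is effectively a linear probability model rather than a true logistic regression. An explicit logistic fit would pass family = binomial; a toy sketch on hypothetical data (not the study dataset, so results would differ from those reported):

```r
# toy data standing in for train_data (hypothetical)
set.seed(42)
toy <- data.frame(x1 = rnorm(100), x2 = rnorm(100))
toy$y <- rbinom(100, 1, plogis(0.5 * toy$x1 - 0.3 * toy$x2))
# family = binomial makes glm() fit a logistic regression with a logit link;
# omitting it (as above) silently falls back to a gaussian linear model
fit <- glm(y ~ x1 + x2, data = toy, family = binomial)
fit$family$family  # "binomial"
```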
# make inference on test data
pred_values = predict(logistic_model, test_data, type = "response")
predictions = ifelse(pred_values > 0.5, 1, 0)
# calculate metrics
accuracy = mean(predictions == test_data$Cath)
recall = sum(predictions == 1 & test_data$Cath == 1) / sum(test_data$Cath == 1)
precision = sum(predictions == 1 & test_data$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)
# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy: 0.828947368421053"
print(paste("Recall: ", recall))
[1] "Recall: 0.830188679245283"
print(paste("Precision: ", precision))
[1] "Precision: 0.916666666666667"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score: 0.871287128712871"
We can observe that the frequentist approach is overall less performant than the Bayesian one. It does achieve better precision, but as stated in the "Goals" section of the report we prefer models with high recall instead.
Conclusions
We successfully used JAGS to estimate the model parameters and built two models with satisfactory performance.
The third model is the most performant but suffers from convergence issues that make it less reliable. These problems are likely linked to the large number of features and might be resolved with more iterations, at the cost of increased computational time.
The fourth model, meanwhile, is much lighter and converged quickly, at the cost of being a simpler model with slightly lower performance.
References
Alizadehsani, R., Roshanzamir, M., & Sani, Z. (2013). Z-Alizadeh Sani [Dataset]. UCI Machine Learning Repository. https://doi.org/10.24432/C5Q31T.
Alizadehsani, R. et al. A data mining approach for diagnosis of coronary artery disease. Comput. Methods Programs Biomed. 111, 52–61 (2013).
National Heart, Lung and Blood Institute. What is Coronary Heart Disease? https://www.nhlbi.nih.gov/health/coronary-heart-disease
Department of Health, New York State. Heart Disease and Stroke Prevention. https://www.health.ny.gov/diseases/cardiovascular/heart_disease/